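{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Datawhale Smart Ocean Construction - Task 4: Model Building" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This part is the model-building module of the Smart Ocean Construction competition. It covers how to build models and how to tune them." ] }, { "cell_type": "markdown", "metadata": { "hide_input": true }, "source": [ "## Learning Objectives\n", "1. Learn how to choose a suitable model and how to use a model for feature selection.\n", "2. Learn to use the random forest, lightGBM, and Xgboost models.\n", "3. Learn how to apply Bayesian optimization in practice." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Contents\n", "1. Model training and prediction\n", " - Random forest\n", " - lightGBM model\n", " - Xgboost model\n", "2. Cross-validation\n", "3. Hyperparameter tuning\n", "4. Model code example on the Smart Ocean dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Training and Prediction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main steps of model training and prediction are: \n", "(1) Import the required libraries. \n", "(2) Preprocess the data: load the dataset and clean it, including missing-value handling, normalization of continuous features, and encoding of categorical features. \n", "(3) Train the model: choose a suitable machine learning model and fit it on the training set to achieve the best fit. \n", "(4) Predict: feed the data to be predicted into the trained model to obtain the predictions.\n", "\n", "Several commonly used classification algorithms are introduced below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random Forest Classification\n", "[Random forest parameter reference](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier) \n", "Random forest is an ensemble learning algorithm that combines many trees; its basic unit is the decision tree.\n", "Its main advantages are: it achieves good accuracy compared with other commonly used algorithms; it runs efficiently on large datasets; it handles high-dimensional input without dimensionality reduction; it can estimate the importance of each feature for the classification task; it produces an internal unbiased estimate of the generalization error while the forest is being built; and it can also cope well with missing values. \n", "\n", "Calling the random forest classifier from sklearn for prediction:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from sklearn import datasets\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import f1_score\n", "# Load the dataset\n", "iris=datasets.load_iris()\n", "feature=iris.feature_names\n", "X = iris.data\n", "y = iris.target\n", "# Random forest\n", "clf=RandomForestClassifier(n_estimators=200)\n", "train_X,test_X,train_y,test_y = train_test_split(X,y,test_size=0.1,random_state=5)\n", "clf.fit(train_X,train_y)\n", "test_pred=clf.predict(test_X)" ] 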
}, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n", "[0.09838896 0.01544017 0.34365936 0.5425115 ]\n" ] } ], "source": [ "# Inspect the feature importances\n", "print(str(feature)+'\\n'+str(clf.feature_importances_))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model is evaluated with the F1 score; [this CSDN post gives a brief explanation of the metric](https://blog.csdn.net/qq_14997473/article/details/82684300)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random forest - macro: 0.818181818181818\n", "Random forest - weighted: 0.8\n" ] } ], "source": [ "# F1 score for model evaluation\n", "# For binary classification, use average='binary'\n", "# To account for class imbalance with a support-weighted average over classes, use average='weighted'\n", "# To ignore class imbalance and take the unweighted mean over classes, use average='macro'\n", "score=f1_score(test_y,test_pred,average='macro')\n", "print(\"Random forest - macro:\",score)\n", "score=f1_score(test_y,test_pred,average='weighted')\n", "print(\"Random forest - weighted:\",score)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## lightGBM Model\n", "[An introductory article on lightGBM](https://mp.weixin.qq.com/s/64xfT9WIgF3yEExpSxyshQ) \n", "[lightGBM Chinese documentation](https://lightgbm.apachecn.org/#/): it explains the hyperparameters in detail and is worth reading carefully." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. Dealing with overfitting in lightGBM:\n", "- Use a smaller max_bin\n", "- Use a smaller num_leaves\n", "- Use min_data_in_leaf and min_sum_hessian_in_leaf\n", "- Use bagging by setting bagging_fraction and bagging_freq\n", "- Use feature subsampling by setting feature_fraction\n", "- Use more training data\n", "- Use lambda_l1, lambda_l2, and min_gain_to_split for regularization\n", "- Try max_depth to avoid growing overly deep trees \n", "2. Speeding up lightGBM training:\n", "- Use bagging by setting bagging_fraction and bagging_freq\n", "- Use feature subsampling by setting feature_fraction\n", "- Use a smaller max_bin\n", "- Use save_binary to speed up data loading in later runs\n", "- Use parallel learning; see the Parallel Learning Guide\n", "3. 
Improving lightGBM accuracy:\n", "- Use a larger max_bin (training may become slower)\n", "- Use a smaller learning_rate with a larger num_iterations\n", "- Use a larger num_leaves (may cause overfitting)\n", "- Use more training data\n", "- Try dart" ] }, { "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1]\ttrain's multi_logloss: 0.975702\tvalidate's multi_logloss: 1.009\n", "[2]\ttrain's multi_logloss: 0.877457\tvalidate's multi_logloss: 0.914377\n", "[3]\ttrain's multi_logloss: 0.794798\tvalidate's multi_logloss: 0.824134\n", "[4]\ttrain's multi_logloss: 0.723326\tvalidate's multi_logloss: 0.750893\n", "[5]\ttrain's multi_logloss: 0.661667\tvalidate's multi_logloss: 0.682191\n", "[6]\ttrain's multi_logloss: 0.607721\tvalidate's multi_logloss: 0.628136\n", "[7]\ttrain's multi_logloss: 0.560519\tvalidate's multi_logloss: 0.574289\n", "[8]\ttrain's multi_logloss: 0.518687\tvalidate's multi_logloss: 0.529814\n", "[9]\ttrain's multi_logloss: 0.481561\tvalidate's multi_logloss: 0.485778\n", "[10]\ttrain's multi_logloss: 0.448635\tvalidate's multi_logloss: 0.449967\n", "[11]\ttrain's multi_logloss: 0.418864\tvalidate's multi_logloss: 0.414047\n", "[12]\ttrain's multi_logloss: 0.392319\tvalidate's multi_logloss: 0.386407\n", "[13]\ttrain's multi_logloss: 0.368389\tvalidate's multi_logloss: 0.357079\n", "[14]\ttrain's multi_logloss: 0.346782\tvalidate's multi_logloss: 0.33318\n", "[15]\ttrain's multi_logloss: 0.327196\tvalidate's multi_logloss: 0.308858\n", "[16]\ttrain's multi_logloss: 0.309539\tvalidate's multi_logloss: 0.288072\n", "[17]\ttrain's multi_logloss: 0.293482\tvalidate's multi_logloss: 0.268706\n", "[18]\ttrain's multi_logloss: 0.278991\tvalidate's multi_logloss: 0.25158\n", "[19]\ttrain's multi_logloss: 0.265753\tvalidate's multi_logloss: 0.23781\n", "[20]\ttrain's multi_logloss: 0.253744\tvalidate's multi_logloss: 0.226251\n", "[21]\ttrain's multi_logloss: 0.242663\tvalidate's multi_logloss: 0.211902\n", "[22]\ttrain's multi_logloss: 0.232649\tvalidate's 
multi_logloss: 0.202371\n", "[23]\ttrain's multi_logloss: 0.223439\tvalidate's multi_logloss: 0.192921\n", "[24]\ttrain's multi_logloss: 0.215001\tvalidate's multi_logloss: 0.182008\n", "[25]\ttrain's multi_logloss: 0.2072\tvalidate's multi_logloss: 0.175881\n", "[26]\ttrain's multi_logloss: 0.200111\tvalidate's multi_logloss: 0.168868\n", "[27]\ttrain's multi_logloss: 0.193543\tvalidate's multi_logloss: 0.160918\n", "[28]\ttrain's multi_logloss: 0.187559\tvalidate's multi_logloss: 0.155117\n", "[29]\ttrain's multi_logloss: 0.182121\tvalidate's multi_logloss: 0.148551\n", "[30]\ttrain's multi_logloss: 0.177063\tvalidate's multi_logloss: 0.141508\n", "[31]\ttrain's multi_logloss: 0.172155\tvalidate's multi_logloss: 0.136823\n", "[32]\ttrain's multi_logloss: 0.167851\tvalidate's multi_logloss: 0.13318\n", "[33]\ttrain's multi_logloss: 0.163832\tvalidate's multi_logloss: 0.127932\n", "[34]\ttrain's multi_logloss: 0.160045\tvalidate's multi_logloss: 0.124999\n", "[35]\ttrain's multi_logloss: 0.156511\tvalidate's multi_logloss: 0.11994\n", "[36]\ttrain's multi_logloss: 0.153185\tvalidate's multi_logloss: 0.117388\n", "[37]\ttrain's multi_logloss: 0.150086\tvalidate's multi_logloss: 0.113542\n", "[38]\ttrain's multi_logloss: 0.147138\tvalidate's multi_logloss: 0.11118\n", "[39]\ttrain's multi_logloss: 0.144376\tvalidate's multi_logloss: 0.107657\n", "[40]\ttrain's multi_logloss: 0.141792\tvalidate's multi_logloss: 0.105666\n", "[41]\ttrain's multi_logloss: 0.139327\tvalidate's multi_logloss: 0.102515\n", "[42]\ttrain's multi_logloss: 0.137023\tvalidate's multi_logloss: 0.101176\n", "[43]\ttrain's multi_logloss: 0.134844\tvalidate's multi_logloss: 0.0975092\n", "[44]\ttrain's multi_logloss: 0.132768\tvalidate's multi_logloss: 0.0948682\n", "[45]\ttrain's multi_logloss: 0.130798\tvalidate's multi_logloss: 0.0939896\n", "[46]\ttrain's multi_logloss: 0.128917\tvalidate's multi_logloss: 0.0915695\n", "[47]\ttrain's multi_logloss: 0.127132\tvalidate's multi_logloss: 
0.0906398\n", "[48]\ttrain's multi_logloss: 0.12546\tvalidate's multi_logloss: 0.0892012\n", "[49]\ttrain's multi_logloss: 0.123835\tvalidate's multi_logloss: 0.0884964\n", "[50]\ttrain's multi_logloss: 0.122284\tvalidate's multi_logloss: 0.087185\n", "[51]\ttrain's multi_logloss: 0.120772\tvalidate's multi_logloss: 0.0849336\n", "[52]\ttrain's multi_logloss: 0.119346\tvalidate's multi_logloss: 0.0835437\n", "[53]\ttrain's multi_logloss: 0.11795\tvalidate's multi_logloss: 0.0829754\n", "[54]\ttrain's multi_logloss: 0.116534\tvalidate's multi_logloss: 0.0819892\n", "[55]\ttrain's multi_logloss: 0.115189\tvalidate's multi_logloss: 0.0808175\n", "[56]\ttrain's multi_logloss: 0.113915\tvalidate's multi_logloss: 0.0791856\n", "[57]\ttrain's multi_logloss: 0.112663\tvalidate's multi_logloss: 0.0778838\n", "[58]\ttrain's multi_logloss: 0.111477\tvalidate's multi_logloss: 0.0767819\n", "[59]\ttrain's multi_logloss: 0.110319\tvalidate's multi_logloss: 0.0761175\n", "[60]\ttrain's multi_logloss: 0.109189\tvalidate's multi_logloss: 0.075811\n", "[61]\ttrain's multi_logloss: 0.108108\tvalidate's multi_logloss: 0.0743217\n", "[62]\ttrain's multi_logloss: 0.107049\tvalidate's multi_logloss: 0.0730824\n", "[63]\ttrain's multi_logloss: 0.106037\tvalidate's multi_logloss: 0.0725497\n", "[64]\ttrain's multi_logloss: 0.105039\tvalidate's multi_logloss: 0.0709544\n", "[65]\ttrain's multi_logloss: 0.104078\tvalidate's multi_logloss: 0.0703405\n", "[66]\ttrain's multi_logloss: 0.103134\tvalidate's multi_logloss: 0.0701205\n", "[67]\ttrain's multi_logloss: 0.102229\tvalidate's multi_logloss: 0.0692772\n", "[68]\ttrain's multi_logloss: 0.101354\tvalidate's multi_logloss: 0.068559\n", "[69]\ttrain's multi_logloss: 0.100491\tvalidate's multi_logloss: 0.0673473\n", "[70]\ttrain's multi_logloss: 0.0995573\tvalidate's multi_logloss: 0.0674286\n", "[71]\ttrain's multi_logloss: 0.0986634\tvalidate's multi_logloss: 0.0674853\n", "[72]\ttrain's multi_logloss: 0.0978101\tvalidate's multi_logloss: 
0.0672873\n", "[73]\ttrain's multi_logloss: 0.0969953\tvalidate's multi_logloss: 0.0673621\n", "[74]\ttrain's multi_logloss: 0.0962072\tvalidate's multi_logloss: 0.066834\n", "[75]\ttrain's multi_logloss: 0.0954358\tvalidate's multi_logloss: 0.06728\n", "[76]\ttrain's multi_logloss: 0.0946999\tvalidate's multi_logloss: 0.0666785\n", "[77]\ttrain's multi_logloss: 0.093984\tvalidate's multi_logloss: 0.0652261\n", "[78]\ttrain's multi_logloss: 0.093268\tvalidate's multi_logloss: 0.0653247\n", "[79]\ttrain's multi_logloss: 0.0925889\tvalidate's multi_logloss: 0.0654675\n", "[80]\ttrain's multi_logloss: 0.0919186\tvalidate's multi_logloss: 0.0649799\n", "[81]\ttrain's multi_logloss: 0.0912796\tvalidate's multi_logloss: 0.0638035\n", "[82]\ttrain's multi_logloss: 0.0906195\tvalidate's multi_logloss: 0.0638154\n", "[83]\ttrain's multi_logloss: 0.0899888\tvalidate's multi_logloss: 0.0642833\n", "[84]\ttrain's multi_logloss: 0.0893663\tvalidate's multi_logloss: 0.0636025\n", "[85]\ttrain's multi_logloss: 0.0887785\tvalidate's multi_logloss: 0.0626043\n", "[86]\ttrain's multi_logloss: 0.0881855\tvalidate's multi_logloss: 0.0623685\n", "[87]\ttrain's multi_logloss: 0.0875885\tvalidate's multi_logloss: 0.0627226\n", "[88]\ttrain's multi_logloss: 0.0870171\tvalidate's multi_logloss: 0.0624081\n", "[89]\ttrain's multi_logloss: 0.0864743\tvalidate's multi_logloss: 0.0625911\n", "[90]\ttrain's multi_logloss: 0.0859347\tvalidate's multi_logloss: 0.0620309\n", "[91]\ttrain's multi_logloss: 0.0854179\tvalidate's multi_logloss: 0.0622157\n", "[92]\ttrain's multi_logloss: 0.0849131\tvalidate's multi_logloss: 0.0617822\n", "[93]\ttrain's multi_logloss: 0.0844192\tvalidate's multi_logloss: 0.0619947\n", "[94]\ttrain's multi_logloss: 0.0839399\tvalidate's multi_logloss: 0.0614539\n", "[95]\ttrain's multi_logloss: 0.0834681\tvalidate's multi_logloss: 0.0616655\n", "[96]\ttrain's multi_logloss: 0.0830149\tvalidate's multi_logloss: 0.06134\n", "[97]\ttrain's multi_logloss: 
0.0825657\tvalidate's multi_logloss: 0.0613612\n", "[98]\ttrain's multi_logloss: 0.0821295\tvalidate's multi_logloss: 0.0611025\n", "[99]\ttrain's multi_logloss: 0.0816869\tvalidate's multi_logloss: 0.0613398\n", "[100]\ttrain's multi_logloss: 0.0812595\tvalidate's multi_logloss: 0.0610704\n", "0.9777777777777777\n", "Training set 0.9717813051146384\n", "Validation set 0.9734471313418682\n" ] } ], "source": [ "import lightgbm as lgb\n", "from sklearn import datasets\n", "from sklearn.model_selection import train_test_split\n", "import numpy as np\n", "from sklearn.metrics import roc_auc_score, accuracy_score, f1_score\n", "import matplotlib.pyplot as plt\n", "\n", "# Load the data\n", "iris = datasets.load_iris()\n", "# Split into training and test sets\n", "X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3)\n", "# Convert to the lightGBM Dataset format\n", "train_data = lgb.Dataset(X_train, label=y_train)\n", "validation_data = lgb.Dataset(X_test, label=y_test)\n", "# Parameters\n", "results = {}\n", "params = {\n", " 'learning_rate': 0.1,\n", " 'lambda_l1': 0.1,\n", " 'lambda_l2': 0.9,\n", " 'max_depth': 1,\n", " 'objective': 'multiclass', # objective function\n", " 'num_class': 3,\n", " 'verbose': -1 \n", "}\n", "# Train the model; record_evaluation stores the per-iteration metrics in results\n", "gbm = lgb.train(params, train_data, valid_sets=(validation_data,train_data),valid_names=('validate','train'),callbacks=[lgb.record_evaluation(results)])\n", "# Predict class probabilities, then take the most probable class\n", "y_pred_test = gbm.predict(X_test)\n", "y_pred_data = gbm.predict(X_train)\n", "y_pred_data = [list(x).index(max(x)) for x in y_pred_data]\n", "y_pred_test = [list(x).index(max(x)) for x in y_pred_test]\n", "# Evaluate the model\n", "print(accuracy_score(y_test, y_pred_test))\n", "print('Training set',f1_score(y_train, y_pred_data,average='macro'))\n", "print('Validation set',f1_score(y_test, y_pred_test,average='macro'))" ] }, { "cell_type": "code", "execution_count": 110, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# The curves below show that the validation loss is higher than the training loss, which suggests that the model is overfitting\n", "lgb.plot_metric(results)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# So we can try setting lambda_l2 to 0.9\n", "lgb.plot_metric(results)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 116, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 绘制重要的特征\n", "lgb.plot_importance(gbm,importance_type = \"split\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## xgboost模型\n", "[XGBoost基础介绍](https://mp.weixin.qq.com/s/AAKPSIHk1iUqCeUibrORqQ) \n", "[XGBoost参数介绍](https://xgboost.readthedocs.io/en/latest/parameter.html) \n", "[XGboost参数调优方法](https://blog.csdn.net/han_xiaoyang/article/details/52665396)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "F1_score: 95.56%\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEWCAYAAABliCz2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWFklEQVR4nO3de7hddX3n8feHmAolXEQIKhi5xNYqN5HBOlIMHaQq+KitomhVELQ+FlFHqVWLgK3VKiidaWdUUMcroyMWMiOD4mikVq0CcqsUbxzlJjHRoIlRkvCdP/ZCD8ecnB04++bv/Xqe82Tttfbe67MXnM9Z+7fWXjtVhSTpN982ow4gSRoOC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWvjRDktcnOW/UOaT5Fs/D13xKMgXsDmyaNvt3qurW+/icJ1XVZ+9busmT5AxgaVX96aizaPK5h69BeGpVLZr2c6/Lfj4kud8o139vTWpujS8LX0ORZKck701yW5JbkvxNkgXdsn2TfC7J6iSrknwkyc7dsg8BS4D/nWRtkr9IsizJzTOefyrJkd30GUk+keTDSX4CHL+l9W8m6xlJPtxN75WkkpyQ5KYkP07y0iT/Ick1SdYk+Ydpjz0+yb8k+a9J7kjy70n+07TlD0myPMmPknw7yYtnrHd67pcCrwee3b32q7v7nZDk+iQ/TfLdJH827TmWJbk5yauTrOxe7wnTlm+X5Owk3+vyfTHJdt2y30/ype41XZ1k2b34T60xZuFrWD4AbASWAo8GjgJO6pYFeAvwEOD3gIcCZwBU1fOB7/Ordw1v63N9TwM+AewMfGSO9ffjscDDgWcD5wBvAI4EHgUcm+QJM+77XWBX4HTgk0l26ZadD9zcvdZnAn87/Q/CjNzvBf4W+Fj32g/s7rMSOAbYETgBeGeSg6c9x4OAnYA9gBOBf0zygG7ZWcBjgP8I7AL8BXBXkj2ATwF/081/DXBBkt22YhtpzFn4GoQLu73ENUkuTLI78GTglVW1rqpWAu8EngNQVd+uqkur6hdV9UPgHcATZn/6vny5qi6sqrvoFeOs6+/TX1fVz6vqM8A64PyqWllVtwD/TO+PyN1WAudU1Yaq+hhwA3B0kocChwGv7Z7rKuA84Pmby11V6zcXpKo+VVXfqZ4vAJ8B/mDaXTYAb+rWfzGwFvjdJNsALwJeUVW3VNWmqvpSVf0C+FPg4qq6uFv3pcDlwFO2YhtpzDlGqEF4+vQDrEkOBRYCtyW5e/Y2wE3d8sXAf6FXWjt0y358HzPcNG36YVtaf59unza9fjO3F027fUvd82yI79Hb
o38I8KOq+umMZYfMknuzkjyZ3juH36H3On4buHbaXVZX1cZpt3/W5dsV2Bb4zmae9mHAs5I8ddq8hcDn58qjyWHhaxhuAn4B7DqjiO72FqCAA6pqdZKnA/8wbfnMU8nW0Ss5ALqx+JlDD9MfM9f659seSTKt9JcAy4FbgV2S7DCt9JcAt0x77MzXeo/bSe4PXAC8ALioqjYkuZDesNhcVgE/B/YFrp6x7CbgQ1X14l97lH5jOKSjgauq2+gNO5ydZMck23QHau8ettmB3rDDmm4s+dQZT3E7sM+0298Etk1ydJKFwF8B978P659vi4FTkixM8ix6xyUurqqbgC8Bb0mybZID6I2xf2QLz3U7sFc3HAPwW/Re6w+Bjd3e/lH9hOqGt94HvKM7eLwgyeO6PyIfBp6a5I+6+dt2B4D33PqXr3Fl4WtYXkCvrL5Bb7jmE8CDu2VnAgcDd9A7cPjJGY99C/BX3TGB11TVHcDL6I1/30Jvj/9mtmxL659v/0rvAO8q4M3AM6tqdbfsOGAvenv7/wSc3o2Xz+Z/df+uTnJl987gFODj9F7Hc+m9e+jXa+gN/3wN+BHwd8A23R+jp9E7K+iH9Pb4T8WO+I3iB6+keZTkeHofEjts1FmkmfzrLUmNsPAlqREO6UhSI9zDl6RGjO15+DvvvHMtXbp01DG2yrp169h+++1HHWOrTFrmScsLZh4WM/dcccUVq6pqs5fEGNvC33333bn88stHHWOrrFixgmXLlo06xlaZtMyTlhfMPCxm7knyvdmWOaQjSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWpEqmrUGTZryT5La5tj/37UMbbKq/ffyNnX3m/UMbbKpGWetLxg5mEZt8xTbz16zvusWLGCZcuWzet6k1xRVYdsbpl7+JLUCAtfkhph4UvSiNx5550ceuihHHjggTzqUY/i9NNPv8fys846iySsWrVqXtY3sMJPckqS65NUkmu6ny8lOXBQ65SkSbJw4UI+97nPcfXVV3PVVVdxySWX8JWvfAWAm266iUsvvZQlS5bM2/oGuYf/MuApwOOBJ1TVAcBfA+8Z4DolaWIkYdGiRQBs2LCBDRs2kASAV73qVbztbW/75e35MJDCT/IuYB9gOfDYqvpxt+grwJ6DWKckTaJNmzZx0EEHsXjxYp74xCfy2Mc+luXLl7PHHntw4IHzOyAysNMyk0wBh1TVqmnzXgM8oqpOmuUxLwFeArDrrrs95o3nnDuQbIOy+3Zw+/pRp9g6k5Z50vKCmYdl3DLvv8dOc95n7dq1v9zDX7t2Laeddhonn3wyZ511Fm9/+9tZtGgRz3nOc3j3u9/NTjvN/XwARxxxxKynZQ7tpNUkRwAnAofNdp+qeg/dkM+SfZbWOJ1T249xOw+4H5OWedLygpmHZdwyTz1v2Zz3mXke/hVXXMGtt97K6tWrOfnkkwFYtWoVL3/5y/nqV7/Kgx70oPuUaShn6SQ5ADgPeFpVrR7GOiVp3K1Zs4Y1a9YAsH79ej772c/y6Ec/mpUrVzI1NcXU1BR77rknV1555X0uexjCHn6SJcAngedX1TcHvT5JmhSrV6/miCOOYNOmTdx1110ce+yxHHPMMQNb3zDe/7wReCDw37qjzRtn
G1+SpJbsu+++fP3rX9/ifaampuZtfQMr/Kraq5s8qfuRJI2Qn7SVpEaMzyHtGbZbuIAb+rja3DhZsWJFX0fmx8mkZZ60vGDmYZnEzMPmHr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDWir8JPsm+S+3fTy7rvq915oMkkSfOq3z38C4BNSZYC7wX2Bj46sFSSpHnXb+HfVVUbgWcA51TVq4AHDy6WJGm+9Vv4G5IcB7wQ+D/dvIWDiSRJGoR+C/8E4HHAm6vqxiR7Ax8eXCxJ0nzr6/LIVfWNJK8FlnS3bwTeOshgkqT51e9ZOk8FrgIu6W4flGT5AHNJkuZZv0M6ZwCHAmsAquoqemfqSJImRL+Fv7Gq7pgxr+Y7jCRpcPr9isPrkjwXWJDk4cApwJcGF0uSNN/63cN/OfAo4Bf0PnB1B/DKAWWSJA3AnHv4SRYAy6vqSOANg48kSRqEOffwq2oT8LMkOw0hjyRpQPodw/85cG2SS4F1d8+sqlMGkkqSNO/6LfxPdT+SpAnV7ydtPzDoIJKkweqr8JPcyGbOu6+qfeY9kSRpIPod0jlk2vS2wLOAXeY/jiRpUPo6D7+qVk/7uaWqzgH+cLDRJEnzqd8hnYOn3dyG3h7/DgNJJEkaiH6HdM6eNr0RuBE4dv7jSJIGpd/CP7Gqvjt9RvclKJKkCdHvtXQ+0ec8SdKY2uIefpJH0Lto2k5J/njaoh3pna0jSZoQcw3p/C5wDLAz8NRp838KvHhAmSRJA7DFwq+qi4CLkjyuqr48pEySpAHo96Dt15P8Ob3hnV8O5VTViwaSSpI07/o9aPsh4EHAHwFfAPakN6wjSZoQ/Rb+0qo6DVjXXUjtaGD/wcWSJM23fgt/Q/fvmiT7ATsBew0kkSRpIPodw39PkgcApwHLgUXAGweWSpI07/q9Hv553eQXAC+JLEkTqK8hnSS7J3lvkv/b3X5kkhMHG02SNJ/6HcP/H8CngYd0t78JvHIAeSRJA9Jv4e9aVR8H7gKoqo3ApoGlkiTNu34Lf12SB9J9zWGS3wfuGFgqSdK86/csnf9M7+ycfZP8C7Ab8MyBpZIkzbu5rpa5pKq+X1VXJnkCvYupBbihqjZs6bGSpPEy15DOhdOmP1ZV/1ZV11n2kjR55ir8TJv2/HtJmmBzFX7NMi1JmjBzHbQ9MMlP6O3pb9dN092uqtpxoOkkSfNmri9AWTCsIJKkwer3PHxJ0oSz8CWpERa+JDXCwpekRlj4ktSIfq+lM3TrN2xir7/81KhjbJVX77+R4808q6m3Hj2U9UjaPPfwJakRFr4kNcLC19h50YtexOLFi9lvv/1+Oe+0007jgAMO4KSTTuKoo47i1ltvHWFCaTINrPCTnJLk+iQ/TnJNkquSXJ7ksEGtU78Zjj/+eC655JJ7zDv11FO55pprOO+88zjmmGN405veNKJ00uQa5EHblwFPBn4IrKuqSnIA8HHgEQNcrybc4YcfztTU1D3m7bjjry7btG7dOpIgaesMpPCTvIve5ZSXA++rqnd2i7bHq27qXnrDG97Aueeey+LFi/n85z8/6jjSxEnVYPo3yRRwSFWtSvIM4C3AYuDoqvryLI95CfASgF133e0xbzzn3IFkG5Tdt4Pb1486xdYZZub999ip7/v+4Ac/4HWvex3vf//77zF/7dq1XHTRRdx5552ccMIJ8x1xINauXcuiRYtGHWOrmHk4BpH5iCOOuKKqDtncsqEU/rR5hwNvrKoj53r8kn2W1jbH/v1Asg3Kq/ffyNnXju1HGzZrmJm35jz8qakpjjnmGK677rp7zF+xYgV77703Rx999K8tG1crVqxg2bJlo46xVcw8HIPInGTWwh/qWTpVdRm9L0LfdZjr1eT71re+9cvp5cuX84hHeBhI2loD37VLshT4TnfQ9mDgt4DVg16vJtdxxx3HihUrWLVqFXvuuSdnnnkmF198MTfccAPr16/nkY98
JO9617tGHVOaOMN4L/8nwAuSbADWA8+uQY0j6TfC+eef/2vzTjzxRGAy37ZL42JghV9Ve3WTf9f9SJJGyE/aSlIjLHxJasTYnkO43cIF3DBhl9NdsWIFU89bNuoYW2USM0u6d9zDl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1IlU16gybleSnwA2jzrGVdgVWjTrEVpq0zJOWF8w8LGbueVhV7ba5Bfeb5xXNpxuq6pBRh9gaSS4382BNWl4w87CYeW4O6UhSIyx8SWrEOBf+e0Yd4F4w8+BNWl4w87CYeQ5je9BWkjS/xnkPX5I0jyx8SWrEWBZ+kicluSHJt5P85ajz9CPJVJJrk1yV5PJR55kpyfuSrExy3bR5uyS5NMm3un8fMMqMM82S+Ywkt3Tb+aokTxllxpmSPDTJ55Ncn+Tfkryimz+W23oLecd2OyfZNslXk1zdZT6zmz+W2xi2mHmo23nsxvCTLAC+CTwRuBn4GnBcVX1jpMHmkGQKOKSqxvKDH0kOB9YCH6yq/bp5bwN+VFVv7f6wPqCqXjvKnNPNkvkMYG1VnTXKbLNJ8mDgwVV1ZZIdgCuApwPHM4bbegt5j2VMt3OSANtX1dokC4EvAq8A/pgx3MawxcxPYojbeRz38A8Fvl1V362qO4H/CTxtxJkmXlVdBvxoxuynAR/opj9A7xd9bMySeaxV1W1VdWU3/VPgemAPxnRbbyHv2Kqetd3Nhd1PMabbGLaYeajGsfD3AG6advtmxvx/wE4Bn0lyRZKXjDpMn3avqtug94sPLB5xnn6dnOSabshnbN62z5RkL+DRwL8yAdt6Rl4Y4+2cZEGSq4CVwKVVNfbbeJbMMMTtPI6Fn83MG69xp817fFUdDDwZ+PNuOELz778D+wIHAbcBZ480zSySLAIuAF5ZVT8ZdZ65bCbvWG/nqtpUVQcBewKHJtlvxJHmNEvmoW7ncSz8m4GHTru9J3DriLL0rapu7f5dCfwTvaGpcXd7N4Z791juyhHnmVNV3d794twFnMsYbudujPYC4CNV9clu9thu683lnYTtDFBVa4AV9MbCx3YbTzc987C38zgW/teAhyfZO8lvAc8Blo840xYl2b474EWS7YGjgOu2/KixsBx4YTf9QuCiEWbpy92/0J1nMGbbuTs4917g+qp6x7RFY7mtZ8s7zts5yW5Jdu6mtwOOBP6dMd3GMHvmYW/nsTtLB6A7NekcYAHwvqp682gTbVmSfejt1UPvCqQfHbfMSc4HltG7HOvtwOnAhcDHgSXA94FnVdXYHCSdJfMyem9/C5gC/uzucdtxkOQw4J+Ba4G7utmvpzcuPnbbegt5j2NMt3OSA+gdlF1Ab6f141X1piQPZAy3MWwx84cY4nYey8KXJM2/cRzSkSQNgIUvSY2w8CWpERa+JDXCwpekRozzl5hLA5FkE73TEO/29KqaGlEcaWg8LVPNSbK2qhYNcX33q6qNw1qfNBuHdKQZkjw4yWXd9cmvS/IH3fwnJbmyu6b5/+vm7ZLkwu7iV1/pPmBz93XO35PkM8AHu09aXpDka93P40f4EtUoh3TUou26qxYC3FhVz5ix/LnAp6vqzd33M/x2kt3oXevk8Kq6Mcku3X3PBL5eVU9P8ofA
B+l9chLgMcBhVbU+yUeBd1bVF5MsAT4N/N7AXqG0GRa+WrS+u2rhbL4GvK+7qNiFVXVVkmXAZVV1I8C0j+wfBvxJN+9zSR6YZKdu2fKqWt9NHwk8snfpGgB2TLJDdw16aSgsfGmGqrqsu7z10cCHkrwdWMPmL9O9pct5r5s2bxvgcdP+AEhD5xi+NEOShwErq+pceleSPBj4MvCEJHt397l7SOcy4HndvGXAqlmuf/8Z4ORp6zhoQPGlWbmHL/26ZcCpSTbQ+07dF1TVD7tvMvtkkm3oXWv9icAZwPuTXAP8jF9dnnemU4B/7O53P3p/KF460FchzeBpmZLUCId0JKkRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqxP8HeF5Uwdd3eicAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.datasets import load_iris\n", "import xgboost as xgb\n", "from xgboost import plot_importance\n", "from matplotlib import pyplot as plt\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import f1_score # 准确率\n", "# 加载样本数据集\n", "iris = load_iris()\n", "X,y = iris.data,iris.target\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234565) # 数据集分割\n", "# 算法参数\n", "params = {\n", " 'booster': 'gbtree',\n", " 'objective': 'multi:softmax',\n", " 'eval_metric':'mlogloss',\n", " 'num_class': 3,\n", " 'gamma': 0.1,\n", " 'max_depth': 6,\n", " 'lambda': 2,\n", " 'subsample': 0.7,\n", " 'colsample_bytree': 0.75,\n", " 'min_child_weight': 3,\n", " 'eta': 0.1,\n", " 'seed': 1,\n", " 'nthread': 4,\n", "}\n", "\n", "# plst = params.items()\n", "\n", "train_data = xgb.DMatrix(X_train, y_train) # 生成数据集格式\n", "num_rounds = 500\n", "model = xgb.train(params, train_data) # xgboost模型训练\n", "\n", "# 对测试集进行预测\n", "dtest = xgb.DMatrix(X_test)\n", "y_pred = model.predict(dtest)\n", "\n", "# 计算准确率\n", "F1_score = f1_score(y_test,y_pred,average='macro')\n", "print(\"F1_score: %.2f%%\" % (F1_score*100.0))\n", "\n", "# 显示重要特征\n", "plot_importance(model)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "https://blog.csdn.net/han_xiaoyang/article/details/52665396\n", "https://www.cnblogs.com/TimVerion/p/11436001.html" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 交叉验证" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "交叉验证是验证分类器性能的一种统计分析方法,其基本思想在某种意义下将原始数据进行分组,一部分作为训练集,另一部分作为验证集。首先是用训练集对分类器进行训练,再利用验证集来测试所得到的的模型,以此来作为评价分类器的性能指标。常用的交叉验证方法包括简单交叉验证、K折交叉验证、留一法交叉验证和留P法交叉验证" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1.简单交叉验证(cross validation) \n", 
"简单交叉验证是将原始数据分为两组,一组作为训练集,另一组作为验证集,利用训练集训练分类器,然后利用验证集验证模型,将最后的分类准确率作为此分类器的性能指标。通常是划分30%的数据作为测试数据\n", "2.K折交叉验证(K-Fold cross validation) \n", "K折交叉验证是将原始数据分为K组,然后将每个子集数据分别做一次验证集,其余的K-1组子集作为训练集,这样就会得到K个模型,将K个模型最终的验证集的分类准确率取平均值,作为K折交叉验证分类器的性能指标。通常设置为K为5或者10. \n", "3.留一法交叉验证(Leave-One-Out Cross Validation,LOO-CV)\n", "留一法交叉验证是指每个训练集由除一个样本之外的其余样本组成,留下的一个样本组成检验集。这样对于N个样本的数据集,可以组成N个不同的训练集和N个不同的验证集,因此该方法会得到N个模型,用N个模型最终的验证集的分类准确率的平均是作为分类器的性能指标。 \n", "4.留P法交叉验证 \n", "该方法与留一法类似,是从完整数据集中删除P个样本,产生所有可能的训练集和验证集。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "交叉验证示例代码" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1.简单交叉验证" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "from sklearn import datasets\n", "#数据集导入\n", "iris=datasets.load_iris()\n", "feature=iris.feature_names\n", "X = iris.data\n", "y = iris.target\n", "X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.4,random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2.K折交叉验证" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import KFold\n", "folds = KFold(n_splits=10, shuffle=is_shuffle)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "3.留一法交叉验证" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import LeaveOneOut\n", "loo=LeaveOneOut()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "4.留P法交叉验证" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import LeavePOut\n", "lpo=LeavePOut(p=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "另外还有一些其他交叉验证的分割方法,如基于类标签,具有分层的交叉验证。这一类交叉验证方法主要用于解决样本不平衡的问题。 \n", "这种情况下常用StratifiedKFold和StratifiedShuffleSplit的分层抽样方法,可以确保相应的类别频率在每个训练和验证的(fold)中得以保留。 \n", 
"StratifiedKFold:是K-fold的变种,会返回stratified(分层)的折叠:每个小集合中的各个类别的样本比例大致和完整数据集相同。 \n", "StratifiedShuffleSplit:是ShuffleSplit的一种变种,会返回直接的划分,比如创建一个划分,但是划分中的每个类的比例和完整数据集中的相同。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 模型调参\n", "调参就是对模型的参数进行调整,找到使模型最优的超参数,调参的目标就是尽可能达到整体模型的最优" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1.网格搜索 \n", "网格搜索就是一种穷举搜索,在所有候选的参数选择中通过循环遍历去在所有候选参数中寻找表现最好的结果。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2.学习曲线 \n", "学习曲线是在训练集大小不同时通过绘制模型训练集和交叉验证集上的准确率来观察模型在新数据上的表现,进而来判断模型是否方差偏高或偏差过高,以及增大训练集是否可以减小过拟合。 \n", "\n", "\"img\"\n", "1、当训练集和测试集的误差收敛但却很高时,为高偏差。 \n", "\n", "左上角的偏差很高,训练集和验证集的准确率都很低,很可能是欠拟合。 \n", "我们可以增加模型参数,比如,构建更多的特征,减小正则项。 \n", "此时通过增加数据量是不起作用的。\n", " \n", "2、当训练集和测试集的误差之间有大的差距时,为高方差。 \n", "\n", "当训练集的准确率比其他独立数据集上的测试结果的准确率要高时,一般都是过拟合。 \n", "右上角方差很高,训练集和验证集的准确率相差太多,应该是过拟合。 \n", "我们可以增大训练集,降低模型复杂度,增大正则项,或者通过特征选择减少特征数。 \n", "理想情况是是找到偏差和方差都很小的情况,即收敛且误差较小。" ] }, { "attachments": { "image.png": { "image/png": "" } }, "cell_type": "markdown", "metadata": {}, "source": [ "3.验证曲线\n", "和学习曲线不同,验证曲线的横轴为某个超参数的一系列值,由此比较不同超参数设置下的模型准确值。从下图的验证曲线可以看到,随着超参数设置的改变,模型可能会有从欠拟合到合适再到过拟合的过程,进而可以选择一个合适的超参数设置来提高模型的性能。\n", "![image.png](attachment:image.png)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#以Xgboost为例,该网格搜索代码示例如下\n", "import xgboost as xgb\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.datasets import load_breast_cancer\n", "from sklearn.model_selection import GridSearchCV\n", "\n", "cancer = load_breast_cancer()\n", "x = cancer.data[:50]\n", "y = cancer.target[:50]\n", "train_x, valid_x, train_y, valid_y = train_test_split(x, y, test_size=0.333, random_state=0) # 分训练集和验证集\n", "# 这里不需要Dmatrix\n", "\n", "parameters = {\n", " 'max_depth': [5, 10, 15, 20, 25],\n", " 'learning_rate': [0.01, 0.02, 0.05, 0.1, 0.15],\n", " 'n_estimators': [50, 100, 200, 300, 500],\n", " 'min_child_weight': [0, 2, 5, 10, 20],\n", " 'max_delta_step': 
[0, 0.2, 0.6, 1, 2],\n", " 'subsample': [0.6, 0.7, 0.8, 0.85, 0.95],\n", " 'colsample_bytree': [0.5, 0.6, 0.7, 0.8, 0.9],\n", " 'reg_alpha': [0, 0.25, 0.5, 0.75, 1],\n", " 'reg_lambda': [0.2, 0.4, 0.6, 0.8, 1],\n", " 'scale_pos_weight': [0.2, 0.4, 0.6, 0.8, 1]\n", "}\n", "\n", "xlf = xgb.XGBClassifier(max_depth=10,\n", " learning_rate=0.01,\n", " n_estimators=2000,\n", " verbosity=0,\n", " objective='binary:logistic',\n", " n_jobs=-1,\n", " gamma=0,\n", " min_child_weight=1,\n", " max_delta_step=0,\n", " subsample=0.85,\n", " colsample_bytree=0.7,\n", " colsample_bylevel=1,\n", " reg_alpha=0,\n", " reg_lambda=1,\n", " scale_pos_weight=1,\n", " random_state=1440)\n", "\n", "# With GridSearchCV, fit is called on the search object rather than on the estimator directly\n", "gsearch = GridSearchCV(xlf, param_grid=parameters, scoring='accuracy', cv=3)\n", "gsearch.fit(train_x, train_y)\n", "\n", "print(\"Best score: %0.3f\" % gsearch.best_score_)\n", "print(\"Best parameters set:\")\n", "best_parameters = gsearch.best_estimator_.get_params()\n", "for param_name in sorted(parameters.keys()):\n", " print(\"\\t%s: %r\" % (param_name, best_parameters[param_name]))\n", "# Extremely time-consuming (the grid has 5**10 combinations); the run did not finish on the author's machine" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Code Example on the Smart Ocean Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## lightGBM Model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "from tqdm import tqdm\n", "from sklearn.metrics import classification_report, f1_score\n", "from sklearn.model_selection import StratifiedKFold, KFold,train_test_split\n", "import lightgbm as lgb\n", "import os\n", "import warnings\n", "from hyperopt import fmin, tpe, hp, STATUS_OK, Trials" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "all_df=pd.read_csv('group_df.csv',index_col=0)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "use_train = all_df[all_df['label'] 
!= -1]\n", "use_test = all_df[all_df['label'] == -1]#label为-1时是测试集\n", "use_feats = [c for c in use_train.columns if c not in ['ID', 'label']]\n", "X_train,X_verify,y_train,y_verify= train_test_split(use_train[use_feats],use_train['label'],test_size=0.3,random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1.根据特征的重要性进行特征选择" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "D:\\SOFTWEAR_H\\Anaconda3\\lib\\site-packages\\lightgbm\\callback.py:186: UserWarning: Early stopping is not available in dart mode\n", " warnings.warn('Early stopping is not available in dart mode')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "total feature best score: defaultdict(, {'valid': OrderedDict([('f1_score', 0.9004541298211368)])})\n", "total feature importance: [('pos_neq_zero_speed_q_40', 1783), ('lat_lon_countvec_1_x', 1771), ('rank2_mode_lat', 1737), ('pos_neq_zero_speed_median', 1379), ('pos_neq_zero_speed_q_60', 1369), ('lat_lon_tfidf_0_x', 1251), ('pos_neq_zero_speed_q_80', 1194), ('sample_tfidf_0_x', 1168), ('w2v_9_mean', 1134), ('lat_lon_tfidf_11_x', 963), ('rank3_mode_lat', 946), ('w2v_5_mean', 900), ('w2v_16_mean', 874), ('pos_neq_zero_speed_q_30', 866), ('w2v_12_mean', 862), ('pos_neq_zero_speed_q_70', 856), ('lat_lon_tfidf_9_x', 787), ('grad_tfidf_7_x', 772), ('pos_neq_zero_speed_q_90', 746), ('rank3_mode_cnt', 733), ('grad_tfidf_12_x', 729), ('w2v_4_mean', 697), ('sample_tfidf_14_x', 695), ('lat_lon_tfidf_4_x', 693), ('lat_min', 683), ('w2v_23_mean', 647), ('rank2_mode_lon', 631), ('w2v_26_mean', 626), ('rank1_mode_lon', 620), ('grad_tfidf_15_x', 607), ('speed_neq_zero_speed_q_90', 603), ('grad_tfidf_5_x', 572), ('lat_lon_countvec_22_x', 571), ('lat_lon_countvec_1_y', 565), ('w2v_13_mean', 557), ('w2v_27_mean', 550), ('grad_tfidf_2_x', 507), ('lat_lon_tfidf_20_x', 503), ('lat_lon_countvec_0_x', 499), ('lat_lon_countvec_18_x', 490), 
('sample_tfidf_21_x', 488), ('grad_tfidf_14_x', 484), ('lat_lon_countvec_27_x', 470), ('w2v_22_mean', 466), ('lat_lon_tfidf_1_x', 461), ('direction_nunique', 460), ('lon_max', 457), ('w2v_15_mean', 441), ('grad_tfidf_23_x', 431), ('w2v_19_mean', 429), ('w2v_11_mean', 428), ('lat_lon_tfidf_29_x', 420), ('pos_neq_zero_lon_q_10', 417), ('w2v_3_mean', 411), ('lat_lon_tfidf_0_y', 407), ('sample_tfidf_29_x', 406), ('anchor_cnt', 404), ('grad_tfidf_8_x', 397), ('sample_tfidf_10_x', 397), ('sample_tfidf_12_x', 385), ('w2v_28_mean', 384), ('grad_tfidf_13_x', 381), ('direction_q_90', 380), ('speed_neq_zero_lon_min', 374), ('w2v_25_mean', 371), ('anchor_ratio', 367), ('lat_lon_tfidf_16_x', 367), ('rank1_mode_lat', 365), ('w2v_18_mean', 365), ('sample_tfidf_23_x', 364), ('lon_min', 354), ('grad_tfidf_0_x', 351), ('pos_neq_zero_lat_q_90', 341), ('w2v_20_mean', 341), ('sample_tfidf_4_x', 334), ('lat_lon_tfidf_23_x', 332), ('sample_tfidf_0_y', 328), ('pos_neq_zero_direction_q_90', 326), ('speed_neq_zero_direction_nunique', 326), ('sample_tfidf_19_x', 323), ('lat_lon_countvec_9_x', 319), ('pos_neq_zero_lon_q_90', 314), ('w2v_8_mean', 312), ('grad_tfidf_3_x', 309), ('lon_median', 305), ('pos_neq_zero_speed_q_20', 304), ('lat_lon_countvec_4_x', 304), ('lat_mean', 301), ('speed_neq_zero_lon_max', 301), ('lat_lon_tfidf_14_x', 301), ('speed_neq_zero_lat_min', 300), ('lat_lon_countvec_5_x', 296), ('speed_neq_zero_speed_q_80', 294), ('grad_tfidf_16_x', 293), ('rank3_mode_lon', 292), ('lat_lon_tfidf_18_x', 291), ('w2v_7_mean', 290), ('grad_tfidf_6_x', 285), ('grad_tfidf_20_x', 283), ('grad_tfidf_18_x', 282), ('w2v_0_mean', 280), ('grad_tfidf_21_x', 279), ('grad_tfidf_22_x', 273), ('sample_tfidf_24_x', 273), ('speed_q_90', 271), ('w2v_2_mean', 271), ('lat_max', 264), ('sample_tfidf_9_x', 264), ('grad_tfidf_11_x', 262), ('lon_q_20', 260), ('rank1_mode_cnt', 258), ('speed_max', 256), ('lat_lon_tfidf_12_x', 251), ('pos_neq_zero_lon_q_20', 248), ('lat_lon_tfidf_28_x', 242), 
('speed_neq_zero_direction_q_60', 241), ('sample_tfidf_11_x', 241), ('w2v_17_mean', 241), ('sample_tfidf_13_x', 238), ('w2v_14_mean', 236), ('lat_nunique', 235), ('grad_tfidf_4_x', 234), ('w2v_21_mean', 234), ('sample_tfidf_5_x', 231), ('lat_lon_tfidf_9_y', 225), ('speed_neq_zero_lat_q_90', 222), ('direction_median', 221), ('sample_tfidf_17_x', 220), ('sample_tfidf_14_y', 216), ('lat_lon_tfidf_21_x', 215), ('lon_q_10', 214), ('lat_lon_tfidf_22_x', 214), ('grad_tfidf_26_x', 213), ('grad_tfidf_7_y', 213), ('w2v_29_mean', 212), ('pos_neq_zero_lat_q_80', 210), ('cnt', 209), ('lat_lon_tfidf_4_y', 208), ('direction_q_60', 204), ('sample_tfidf_18_x', 203), ('lat_lon_tfidf_11_y', 203), ('pos_neq_zero_lat_min', 202), ('pos_neq_zero_speed_mean', 201), ('speed_neq_zero_lat_q_70', 200), ('grad_tfidf_12_y', 198), ('sample_tfidf_20_x', 197), ('w2v_1_mean', 194), ('speed_neq_zero_lat_q_40', 193), ('pos_neq_zero_speed_max', 192), ('grad_tfidf_27_x', 192), ('grad_tfidf_15_y', 191), ('lat_lon_tfidf_19_x', 189), ('lat_median', 187), ('lat_lon_tfidf_15_x', 187), ('lat_q_20', 186), ('lat_q_70', 186), ('lon_q_70', 185), ('w2v_24_mean', 184), ('pos_neq_zero_lat_q_40', 183), ('grad_tfidf_25_x', 181), ('w2v_10_mean', 181), ('lon_mean', 180), ('sample_tfidf_27_x', 180), ('w2v_6_mean', 180), ('lat_lon_tfidf_24_x', 178), ('lat_lon_countvec_12_x', 178), ('pos_neq_zero_lat_mean', 177), ('speed_neq_zero_speed_q_70', 174), ('speed_neq_zero_direction_q_80', 172), ('rank2_mode_cnt', 172), ('speed_neq_zero_lat_nunique', 171), ('lat_lon_tfidf_2_x', 171), ('sample_tfidf_25_x', 170), ('lat_lon_tfidf_5_x', 169), ('lat_lon_countvec_26_x', 167), ('grad_tfidf_9_x', 166), ('lat_lon_countvec_28_x', 163), ('lat_lon_countvec_22_y', 163), ('sample_tfidf_1_x', 162), ('pos_neq_zero_direction_nunique', 161), ('pos_neq_zero_speed_q_10', 157), ('sample_tfidf_16_x', 155), ('speed_neq_zero_direction_q_90', 154), ('grad_tfidf_14_y', 153), ('lat_lon_tfidf_7_x', 151), ('pos_neq_zero_direction_q_80', 149), ('lat_q_80', 
148), ('grad_tfidf_23_y', 148), ('lat_lon_countvec_11_x', 147), ('sample_tfidf_22_x', 146), ('speed_neq_zero_lat_max', 144), ('sample_tfidf_15_x', 144), ('grad_tfidf_2_y', 144), ('pos_neq_zero_lat_q_10', 142), ('lat_lon_tfidf_1_y', 142), ('lat_lon_countvec_16_x', 141), ('grad_tfidf_13_y', 138), ('lat_lon_countvec_29_x', 136), ('lat_lon_tfidf_29_y', 136), ('grad_tfidf_5_y', 136), ('direction_max', 135), ('pos_neq_zero_lon_median', 134), ('lat_lon_tfidf_27_x', 134), ('lon_q_80', 133), ('lat_lon_countvec_15_x', 133), ('pos_neq_zero_lon_max', 132), ('lat_lon_countvec_14_x', 132), ('lat_lon_tfidf_26_x', 131), ('grad_tfidf_19_x', 131), ('sample_tfidf_8_x', 131), ('lat_q_60', 130), ('sample_tfidf_28_x', 130), ('lat_lon_countvec_27_y', 130), ('lat_lon_countvec_6_x', 128), ('lat_lon_countvec_0_y', 128), ('sample_tfidf_12_y', 127), ('lat_lon_tfidf_8_x', 126), ('sample_tfidf_29_y', 126), ('lat_lon_countvec_17_x', 125), ('direction_q_70', 124), ('lat_lon_tfidf_20_y', 124), ('lat_lon_tfidf_3_x', 121), ('sample_tfidf_21_y', 120), ('grad_tfidf_0_y', 119), ('pos_neq_zero_lat_median', 118), ('lat_lon_tfidf_16_y', 118), ('grad_tfidf_10_x', 117), ('sample_tfidf_2_x', 116), ('lat_lon_countvec_4_y', 116), ('speed_median', 115), ('pos_neq_zero_direction_q_10', 115), ('speed_neq_zero_lon_mean', 115), ('pos_neq_zero_direction_max', 114), ('lat_q_40', 113), ('grad_tfidf_1_x', 113), ('speed_nunique', 111), ('sample_tfidf_23_y', 111), ('speed_q_30', 110), ('pos_neq_zero_lat_q_30', 110), ('lat_lon_tfidf_10_x', 110), ('lat_lon_countvec_10_x', 110), ('lat_lon_tfidf_23_y', 109), ('pos_neq_zero_speed_min', 106), ('speed_neq_zero_lat_q_60', 106), ('lat_lon_countvec_21_x', 106), ('lat_lon_countvec_18_y', 106), ('lat_lon_tfidf_17_x', 105), ('grad_tfidf_8_y', 103), ('grad_tfidf_6_y', 102), ('sample_tfidf_10_y', 101), ('pos_neq_zero_lon_min', 100), ('lat_lon_countvec_8_x', 100), ('lat_lon_countvec_9_y', 100), ('direction_mean', 99), ('grad_tfidf_21_y', 99), ('lat_lon_tfidf_6_x', 98), 
('lat_lon_tfidf_18_y', 97), ('direction_q_80', 96), ('pos_neq_zero_direction_q_70', 96), ('lat_lon_countvec_20_x', 95), ('speed_neq_zero_direction_q_70', 93), ('lat_lon_countvec_25_x', 93), ('lat_lon_countvec_23_x', 92), ('lat_lon_tfidf_14_y', 92), ('lat_q_90', 91), ('sample_tfidf_7_x', 91), ('pos_neq_zero_lon_q_70', 90), ('lat_lon_countvec_5_y', 90), ('pos_neq_zero_direction_q_20', 89), ('lat_lon_tfidf_12_y', 89), ('lat_lon_tfidf_28_y', 89), ('sample_tfidf_4_y', 89), ('direction_q_40', 88), ('pos_neq_zero_lat_q_20', 87), ('grad_tfidf_17_x', 87), ('sample_tfidf_9_y', 87), ('sample_tfidf_24_y', 87), ('pos_neq_zero_lat_max', 86), ('pos_neq_zero_lon_mean', 86), ('speed_neq_zero_direction_q_40', 86), ('lat_lon_countvec_7_x', 86), ('speed_neq_zero_speed_q_40', 85), ('sample_tfidf_6_x', 84), ('sample_tfidf_19_y', 84), ('speed_min', 83), ('direction_q_10', 83), ('lat_lon_countvec_19_x', 83), ('grad_tfidf_24_x', 83), ('speed_q_60', 82), ('lat_lon_tfidf_25_x', 82), ('sample_tfidf_3_x', 82), ('grad_tfidf_22_y', 82), ('direction_q_30', 80), ('speed_neq_zero_direction_mean', 80), ('grad_tfidf_18_y', 77), ('lat_q_10', 76), ('speed_neq_zero_speed_max', 75), ('grad_tfidf_3_y', 75), ('sample_tfidf_11_y', 75), ('lon_nunique', 74), ('lon_q_90', 74), ('speed_neq_zero_lon_q_10', 74), ('speed_neq_zero_speed_median', 74), ('grad_tfidf_28_x', 74), ('grad_tfidf_20_y', 74), ('speed_neq_zero_lon_q_70', 73), ('lat_lon_tfidf_24_y', 73), ('pos_neq_zero_lat_q_60', 72), ('lat_lon_countvec_2_x', 72), ('lat_lon_countvec_3_x', 69), ('sample_tfidf_20_y', 69), ('lat_lon_tfidf_13_x', 68), ('grad_tfidf_16_y', 68), ('sample_tfidf_13_y', 67), ('speed_neq_zero_lon_q_30', 66), ('speed_q_40', 65), ('grad_tfidf_4_y', 65), ('sample_tfidf_5_y', 65), ('lat_q_30', 64), ('pos_neq_zero_direction_median', 64), ('speed_neq_zero_lat_median', 64), ('grad_tfidf_11_y', 64), ('grad_tfidf_27_y', 64), ('lat_lon_tfidf_19_y', 62), ('pos_neq_zero_lon_q_40', 61), ('lat_lon_countvec_26_y', 61), ('pos_neq_zero_lon_q_80', 60), 
('sample_tfidf_17_y', 60), ('lon_q_40', 59), ('lat_lon_countvec_28_y', 59), ('lat_lon_tfidf_22_y', 57), ('grad_tfidf_29_x', 56), ('lat_lon_countvec_12_y', 56), ('sample_tfidf_15_y', 56), ('sample_tfidf_27_y', 56), ('speed_q_70', 55), ('lat_lon_tfidf_21_y', 55), ('grad_tfidf_9_y', 55), ('sample_tfidf_25_y', 55), ('pos_neq_zero_direction_mean', 54), ('sample_tfidf_26_x', 54), ('sample_tfidf_18_y', 53), ('speed_neq_zero_lon_q_90', 51), ('speed_neq_zero_direction_max', 51), ('lat_lon_tfidf_5_y', 50), ('pos_neq_zero_direction_q_60', 49), ('sample_tfidf_2_y', 49), ('pos_neq_zero_lon_q_60', 48), ('speed_neq_zero_speed_mean', 48), ('lat_lon_tfidf_15_y', 48), ('pos_neq_zero_direction_q_30', 47), ('speed_neq_zero_lon_nunique', 47), ('lat_lon_countvec_24_x', 47), ('sample_tfidf_8_y', 47), ('lat_lon_tfidf_10_y', 46), ('lon_q_60', 45), ('pos_neq_zero_lat_q_70', 45), ('speed_neq_zero_direction_q_10', 45), ('lat_lon_tfidf_3_y', 45), ('speed_neq_zero_lat_mean', 43), ('speed_neq_zero_lat_q_80', 43), ('lat_lon_tfidf_2_y', 43), ('lat_lon_tfidf_8_y', 43), ('grad_tfidf_19_y', 43), ('grad_tfidf_25_y', 43), ('grad_tfidf_26_y', 43), ('lon_q_30', 42), ('speed_neq_zero_lon_q_20', 42), ('pos_neq_zero_speed_nunique', 41), ('speed_neq_zero_speed_nunique', 41), ('speed_neq_zero_speed_q_30', 41), ('lat_lon_tfidf_7_y', 41), ('lat_lon_tfidf_17_y', 41), ('lat_lon_countvec_14_y', 41), ('grad_tfidf_10_y', 41), ('lat_lon_tfidf_26_y', 40), ('grad_tfidf_1_y', 40), ('speed_neq_zero_lat_q_20', 39), ('speed_q_80', 38), ('speed_neq_zero_lat_q_30', 38), ('lat_lon_countvec_15_y', 38), ('pos_neq_zero_direction_q_40', 37), ('speed_neq_zero_direction_median', 37), ('pos_neq_zero_lon_q_30', 36), ('lat_lon_countvec_11_y', 36), ('lat_lon_countvec_21_y', 35), ('sample_tfidf_28_y', 35), ('speed_neq_zero_speed_q_60', 34), ('lat_lon_countvec_29_y', 34), ('sample_tfidf_1_y', 34), ('sample_tfidf_22_y', 34), ('lat_lon_countvec_6_y', 33), ('lat_lon_countvec_10_y', 33), ('lat_lon_countvec_16_y', 33), ('speed_mean', 32), 
('lat_lon_countvec_17_y', 31), ('lat_lon_countvec_23_y', 31), ('speed_neq_zero_direction_q_30', 30), ('lat_lon_tfidf_13_y', 30), ('sample_tfidf_16_y', 30), ('speed_neq_zero_lat_q_10', 29), ('lat_lon_tfidf_27_y', 29), ('grad_tfidf_17_y', 29), ('lat_lon_countvec_13_x', 27), ('lat_lon_countvec_19_y', 27), ('grad_tfidf_24_y', 26), ('speed_neq_zero_lon_q_40', 25), ('lat_lon_tfidf_25_y', 25), ('lat_lon_countvec_8_y', 25), ('speed_neq_zero_lon_median', 24), ('speed_neq_zero_speed_min', 24), ('lat_lon_countvec_25_y', 24), ('sample_tfidf_6_y', 24), ('pos_neq_zero_lat_nunique', 23), ('speed_neq_zero_lon_q_80', 23), ('lat_lon_countvec_20_y', 23), ('speed_neq_zero_speed_q_10', 22), ('lat_lon_countvec_3_y', 22), ('grad_tfidf_28_y', 22), ('sample_tfidf_7_y', 22), ('lat_lon_countvec_7_y', 21), ('sample_tfidf_26_y', 21), ('lat_lon_tfidf_6_y', 20), ('sample_tfidf_3_y', 20), ('grad_tfidf_29_y', 18), ('speed_neq_zero_lon_q_60', 16), ('speed_neq_zero_speed_q_20', 14), ('lat_lon_countvec_24_y', 14), ('lat_lon_countvec_2_y', 11), ('speed_neq_zero_direction_q_20', 9), ('lat_lon_countvec_13_y', 8), ('speed_q_10', 7), ('pos_neq_zero_lon_nunique', 5), ('direction_q_20', 4), ('speed_q_20', 2), ('pos_neq_zero_direction_min', 2), ('direction_min', 0), ('speed_neq_zero_direction_min', 0)]\n", "select forward 200 features:[('pos_neq_zero_speed_q_40', 1783), ('lat_lon_countvec_1_x', 1771), ('rank2_mode_lat', 1737), ('pos_neq_zero_speed_median', 1379), ('pos_neq_zero_speed_q_60', 1369), ('lat_lon_tfidf_0_x', 1251), ('pos_neq_zero_speed_q_80', 1194), ('sample_tfidf_0_x', 1168), ('w2v_9_mean', 1134), ('lat_lon_tfidf_11_x', 963), ('rank3_mode_lat', 946), ('w2v_5_mean', 900), ('w2v_16_mean', 874), ('pos_neq_zero_speed_q_30', 866), ('w2v_12_mean', 862), ('pos_neq_zero_speed_q_70', 856), ('lat_lon_tfidf_9_x', 787), ('grad_tfidf_7_x', 772), ('pos_neq_zero_speed_q_90', 746), ('rank3_mode_cnt', 733), ('grad_tfidf_12_x', 729), ('w2v_4_mean', 697), ('sample_tfidf_14_x', 695), ('lat_lon_tfidf_4_x', 693), 
('lat_min', 683), ('w2v_23_mean', 647), ('rank2_mode_lon', 631), ('w2v_26_mean', 626), ('rank1_mode_lon', 620), ('grad_tfidf_15_x', 607), ('speed_neq_zero_speed_q_90', 603), ('grad_tfidf_5_x', 572), ('lat_lon_countvec_22_x', 571), ('lat_lon_countvec_1_y', 565), ('w2v_13_mean', 557), ('w2v_27_mean', 550), ('grad_tfidf_2_x', 507), ('lat_lon_tfidf_20_x', 503), ('lat_lon_countvec_0_x', 499), ('lat_lon_countvec_18_x', 490), ('sample_tfidf_21_x', 488), ('grad_tfidf_14_x', 484), ('lat_lon_countvec_27_x', 470), ('w2v_22_mean', 466), ('lat_lon_tfidf_1_x', 461), ('direction_nunique', 460), ('lon_max', 457), ('w2v_15_mean', 441), ('grad_tfidf_23_x', 431), ('w2v_19_mean', 429), ('w2v_11_mean', 428), ('lat_lon_tfidf_29_x', 420), ('pos_neq_zero_lon_q_10', 417), ('w2v_3_mean', 411), ('lat_lon_tfidf_0_y', 407), ('sample_tfidf_29_x', 406), ('anchor_cnt', 404), ('grad_tfidf_8_x', 397), ('sample_tfidf_10_x', 397), ('sample_tfidf_12_x', 385), ('w2v_28_mean', 384), ('grad_tfidf_13_x', 381), ('direction_q_90', 380), ('speed_neq_zero_lon_min', 374), ('w2v_25_mean', 371), ('anchor_ratio', 367), ('lat_lon_tfidf_16_x', 367), ('rank1_mode_lat', 365), ('w2v_18_mean', 365), ('sample_tfidf_23_x', 364), ('lon_min', 354), ('grad_tfidf_0_x', 351), ('pos_neq_zero_lat_q_90', 341), ('w2v_20_mean', 341), ('sample_tfidf_4_x', 334), ('lat_lon_tfidf_23_x', 332), ('sample_tfidf_0_y', 328), ('pos_neq_zero_direction_q_90', 326), ('speed_neq_zero_direction_nunique', 326), ('sample_tfidf_19_x', 323), ('lat_lon_countvec_9_x', 319), ('pos_neq_zero_lon_q_90', 314), ('w2v_8_mean', 312), ('grad_tfidf_3_x', 309), ('lon_median', 305), ('pos_neq_zero_speed_q_20', 304), ('lat_lon_countvec_4_x', 304), ('lat_mean', 301), ('speed_neq_zero_lon_max', 301), ('lat_lon_tfidf_14_x', 301), ('speed_neq_zero_lat_min', 300), ('lat_lon_countvec_5_x', 296), ('speed_neq_zero_speed_q_80', 294), ('grad_tfidf_16_x', 293), ('rank3_mode_lon', 292), ('lat_lon_tfidf_18_x', 291), ('w2v_7_mean', 290), ('grad_tfidf_6_x', 285), 
('grad_tfidf_20_x', 283), ('grad_tfidf_18_x', 282), ('w2v_0_mean', 280), ('grad_tfidf_21_x', 279), ('grad_tfidf_22_x', 273), ('sample_tfidf_24_x', 273), ('speed_q_90', 271), ('w2v_2_mean', 271), ('lat_max', 264), ('sample_tfidf_9_x', 264), ('grad_tfidf_11_x', 262), ('lon_q_20', 260), ('rank1_mode_cnt', 258), ('speed_max', 256), ('lat_lon_tfidf_12_x', 251), ('pos_neq_zero_lon_q_20', 248), ('lat_lon_tfidf_28_x', 242), ('speed_neq_zero_direction_q_60', 241), ('sample_tfidf_11_x', 241), ('w2v_17_mean', 241), ('sample_tfidf_13_x', 238), ('w2v_14_mean', 236), ('lat_nunique', 235), ('grad_tfidf_4_x', 234), ('w2v_21_mean', 234), ('sample_tfidf_5_x', 231), ('lat_lon_tfidf_9_y', 225), ('speed_neq_zero_lat_q_90', 222), ('direction_median', 221), ('sample_tfidf_17_x', 220), ('sample_tfidf_14_y', 216), ('lat_lon_tfidf_21_x', 215), ('lon_q_10', 214), ('lat_lon_tfidf_22_x', 214), ('grad_tfidf_26_x', 213), ('grad_tfidf_7_y', 213), ('w2v_29_mean', 212), ('pos_neq_zero_lat_q_80', 210), ('cnt', 209), ('lat_lon_tfidf_4_y', 208), ('direction_q_60', 204), ('sample_tfidf_18_x', 203), ('lat_lon_tfidf_11_y', 203), ('pos_neq_zero_lat_min', 202), ('pos_neq_zero_speed_mean', 201), ('speed_neq_zero_lat_q_70', 200), ('grad_tfidf_12_y', 198), ('sample_tfidf_20_x', 197), ('w2v_1_mean', 194), ('speed_neq_zero_lat_q_40', 193), ('pos_neq_zero_speed_max', 192), ('grad_tfidf_27_x', 192), ('grad_tfidf_15_y', 191), ('lat_lon_tfidf_19_x', 189), ('lat_median', 187), ('lat_lon_tfidf_15_x', 187), ('lat_q_20', 186), ('lat_q_70', 186), ('lon_q_70', 185), ('w2v_24_mean', 184), ('pos_neq_zero_lat_q_40', 183), ('grad_tfidf_25_x', 181), ('w2v_10_mean', 181), ('lon_mean', 180), ('sample_tfidf_27_x', 180), ('w2v_6_mean', 180), ('lat_lon_tfidf_24_x', 178), ('lat_lon_countvec_12_x', 178), ('pos_neq_zero_lat_mean', 177), ('speed_neq_zero_speed_q_70', 174), ('speed_neq_zero_direction_q_80', 172), ('rank2_mode_cnt', 172), ('speed_neq_zero_lat_nunique', 171), ('lat_lon_tfidf_2_x', 171), ('sample_tfidf_25_x', 170), 
('lat_lon_tfidf_5_x', 169), ('lat_lon_countvec_26_x', 167), ('grad_tfidf_9_x', 166), ('lat_lon_countvec_28_x', 163), ('lat_lon_countvec_22_y', 163), ('sample_tfidf_1_x', 162), ('pos_neq_zero_direction_nunique', 161), ('pos_neq_zero_speed_q_10', 157), ('sample_tfidf_16_x', 155), ('speed_neq_zero_direction_q_90', 154), ('grad_tfidf_14_y', 153), ('lat_lon_tfidf_7_x', 151), ('pos_neq_zero_direction_q_80', 149), ('lat_q_80', 148), ('grad_tfidf_23_y', 148), ('lat_lon_countvec_11_x', 147), ('sample_tfidf_22_x', 146), ('speed_neq_zero_lat_max', 144), ('sample_tfidf_15_x', 144), ('grad_tfidf_2_y', 144), ('pos_neq_zero_lat_q_10', 142), ('lat_lon_tfidf_1_y', 142), ('lat_lon_countvec_16_x', 141), ('grad_tfidf_13_y', 138), ('lat_lon_countvec_29_x', 136), ('lat_lon_tfidf_29_y', 136), ('grad_tfidf_5_y', 136)]\n" ] } ], "source": [ "############## Feature-selection settings ###################\n", "selectFeatures = 200 # number of features to keep\n", "earlyStopping = 100 # early-stopping rounds\n", "select_num_boost_round = 1000 # boosting rounds for the feature-selection model\n", "# Set the base parameters first\n", "selfParam = {\n", "    'learning_rate':0.01, # learning rate\n", "    'boosting':'dart', # boosting type: gbdt or dart\n", "    'objective':'multiclass', # multiclass classification\n", "    'metric':'None',\n", "    'num_leaves':32, # maximum number of leaves per tree\n", "    'feature_fraction':0.7, # fraction of features sampled for each tree\n", "    'bagging_fraction':0.8, # fraction of samples used for training\n", "    'min_data_in_leaf':30, # minimum number of samples per leaf\n", "    'num_class': 3,\n", "    'max_depth':6, # maximum tree depth\n", "\n", "    'num_threads':8, # number of LightGBM threads\n", "    'min_data_in_bin':30, # minimum number of samples per bin\n", "    'max_bin':256, # maximum number of bins\n", "    'is_unbalance':True, # imbalanced classes\n", "    'train_metric':True,\n", "    'verbose':-1,\n", "}\n", "# Feature selection ---------------------------------------------------------------------------------\n", "def f1_score_eval(preds, valid_df):\n", "    labels = valid_df.get_label()\n", "    # This code assumes preds arrive flattened class by class, hence the reshape to (num_class, n_samples)\n", "    preds = np.argmax(preds.reshape(3, -1), axis=0)\n", "    scores = f1_score(y_true=labels, y_pred=preds, average='macro')\n", "    return 'f1_score', scores, True\n", "\n", "train_data = lgb.Dataset(data=X_train,label=y_train,feature_name=use_feats)\n", "valid_data = 
lgb.Dataset(data=X_verify,label=y_verify,reference=train_data,feature_name=use_feats)\n", "\n", "sm = lgb.train(params=selfParam,train_set=train_data,num_boost_round=select_num_boost_round,\n", "               valid_sets=[valid_data],valid_names=['valid'],\n", "               feature_name=use_feats,\n", "               early_stopping_rounds=earlyStopping,verbose_eval=False,keep_training_booster=True,feval=f1_score_eval)\n", "features_importance = {k:v for k,v in zip(sm.feature_name(),sm.feature_importance(iteration=sm.best_iteration))}\n", "sort_feature_importance = sorted(features_importance.items(),key=lambda x:x[1],reverse=True)\n", "print('total feature best score:', sm.best_score)\n", "print('total feature importance:',sort_feature_importance)\n", "print('select forward {} features:{}'.format(selectFeatures,sort_feature_importance[:selectFeatures]))\n", "# model_feature holds the names of the selected features (the top selectFeatures by importance)\n", "model_feature = [k[0] for k in sort_feature_importance[:selectFeatures]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Bayesian optimization ([introduction](https://github.com/FontTian/hyperopt-doc-zh/wiki/FMin)) is another method commonly used for tuning during modeling. The code below uses it to select hyperparameters." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "训练集f1_score:1.0,测试集f1_score:0.9238060849905194,loss_f1_score:0.07619391500948058 \n", "训练集f1_score:0.9414337502771342,测试集f1_score:0.8878751759836653,loss_f1_score:0.11212482401633472 \n", "训练集f1_score:1.0,测试集f1_score:0.9275451088133652,loss_f1_score:0.07245489118663484 \n", "训练集f1_score:1.0,测试集f1_score:0.9262405937033683,loss_f1_score:0.07375940629663169 \n", "训练集f1_score:0.9708237804866381,测试集f1_score:0.9105982243190386,loss_f1_score:0.08940177568096142 \n", "训练集f1_score:0.9689912364726484,测试集f1_score:0.9086459359345839,loss_f1_score:0.09135406406541613 \n", "训练集f1_score:0.9841597696688008,测试集f1_score:0.9027075194168233,loss_f1_score:0.09729248058317674 \n", "训练集f1_score:1.0,测试集f1_score:0.9215512877825286,loss_f1_score:0.0784487122174714 
\n", "训练集f1_score:1.0,测试集f1_score:0.924555451978199,loss_f1_score:0.075444548021801 \n", "训练集f1_score:0.998357894114157,测试集f1_score:0.9157797895654226,loss_f1_score:0.08422021043457739 \n", "训练集f1_score:1.0,测试集f1_score:0.9225868784774544,loss_f1_score:0.07741312152254565 \n", "训练集f1_score:1.0,测试集f1_score:0.9188521505717673,loss_f1_score:0.08114784942823272 \n", "训练集f1_score:0.9268245763808158,测试集f1_score:0.8763935795977332,loss_f1_score:0.12360642040226677 \n", "训练集f1_score:1.0,测试集f1_score:0.9215959099478135,loss_f1_score:0.07840409005218651 \n", "训练集f1_score:1.0,测试集f1_score:0.9265015559936258,loss_f1_score:0.07349844400637418 \n", "训练集f1_score:1.0,测试集f1_score:0.9143628354188641,loss_f1_score:0.0856371645811359 \n", "训练集f1_score:1.0,测试集f1_score:0.9202754009210264,loss_f1_score:0.07972459907897356 \n", "训练集f1_score:0.9550283459834631,测试集f1_score:0.8923546584333147,loss_f1_score:0.10764534156668526 \n", "训练集f1_score:1.0,测试集f1_score:0.9255732985564632,loss_f1_score:0.0744267014435368 \n", "训练集f1_score:1.0,测试集f1_score:0.926093875740129,loss_f1_score:0.07390612425987098 \n", "训练集f1_score:1.0,测试集f1_score:0.9275189170142104,loss_f1_score:0.07248108298578959 \n", "训练集f1_score:1.0,测试集f1_score:0.9257895202231272,loss_f1_score:0.07421047977687278 \n", "训练集f1_score:1.0,测试集f1_score:0.9248738969479765,loss_f1_score:0.0751261030520235 \n", "训练集f1_score:1.0,测试集f1_score:0.9272520229049039,loss_f1_score:0.07274797709509606 \n", "训练集f1_score:1.0,测试集f1_score:0.9256769527801775,loss_f1_score:0.07432304721982252 \n", "训练集f1_score:1.0,测试集f1_score:0.9252959646692677,loss_f1_score:0.07470403533073233 \n", "训练集f1_score:1.0,测试集f1_score:0.9280536344383128,loss_f1_score:0.07194636556168721 \n", "训练集f1_score:1.0,测试集f1_score:0.9316114105930104,loss_f1_score:0.06838858940698955 \n", "训练集f1_score:1.0,测试集f1_score:0.9282603014798921,loss_f1_score:0.07173969852010786 \n", "训练集f1_score:1.0,测试集f1_score:0.9169851848129301,loss_f1_score:0.08301481518706988 \n", 
"训练集f1_score:0.9998006409358186,测试集f1_score:0.9170084634982812,loss_f1_score:0.08299153650171875 \n", "训练集f1_score:1.0,测试集f1_score:0.919142326688697,loss_f1_score:0.080857673311303 \n", "训练集f1_score:1.0,测试集f1_score:0.927350422658861,loss_f1_score:0.07264957734113897 \n", "训练集f1_score:1.0,测试集f1_score:0.9248086877712395,loss_f1_score:0.07519131222876052 \n", "训练集f1_score:1.0,测试集f1_score:0.9170626453496801,loss_f1_score:0.08293735465031993 \n", "训练集f1_score:1.0,测试集f1_score:0.9277641923766077,loss_f1_score:0.07223580762339232 \n", "训练集f1_score:1.0,测试集f1_score:0.9221988666312404,loss_f1_score:0.0778011333687596 \n", "训练集f1_score:1.0,测试集f1_score:0.9225220095934339,loss_f1_score:0.07747799040656611 \n", "训练集f1_score:1.0,测试集f1_score:0.9239565521812777,loss_f1_score:0.0760434478187223 \n", "训练集f1_score:1.0,测试集f1_score:0.9276828960144917,loss_f1_score:0.07231710398550828 \n", "训练集f1_score:1.0,测试集f1_score:0.9205931627810685,loss_f1_score:0.07940683721893149 \n", "训练集f1_score:1.0,测试集f1_score:0.9262928923256212,loss_f1_score:0.07370710767437882 \n", "训练集f1_score:0.9944566925965641,测试集f1_score:0.9103100448505551,loss_f1_score:0.08968995514944489 \n", "训练集f1_score:1.0,测试集f1_score:0.9267901922541096,loss_f1_score:0.07320980774589037 \n", "训练集f1_score:1.0,测试集f1_score:0.920503002249437,loss_f1_score:0.07949699775056296 \n", "训练集f1_score:0.9315809154440894,测试集f1_score:0.888114739372245,loss_f1_score:0.11188526062775495 \n", "训练集f1_score:1.0,测试集f1_score:0.9312944518110373,loss_f1_score:0.06870554818896268 \n", "训练集f1_score:1.0,测试集f1_score:0.9303459748533314,loss_f1_score:0.06965402514666863 \n", "训练集f1_score:1.0,测试集f1_score:0.931353840440614,loss_f1_score:0.06864615955938602 \n", "训练集f1_score:1.0,测试集f1_score:0.9229280238009058,loss_f1_score:0.07707197619909423 \n", "训练集f1_score:1.0,测试集f1_score:0.9081707271979852,loss_f1_score:0.0918292728020148 \n", "训练集f1_score:1.0,测试集f1_score:0.9263682433473132,loss_f1_score:0.07363175665268684 \n", 
"训练集f1_score:0.9979810910594639,测试集f1_score:0.9137152734108268,loss_f1_score:0.08628472658917319 \n", "训练集f1_score:1.0,测试集f1_score:0.9258220879299731,loss_f1_score:0.07417791207002689 \n", "训练集f1_score:1.0,测试集f1_score:0.9174454505221505,loss_f1_score:0.08255454947784946 \n", "训练集f1_score:1.0,测试集f1_score:0.9271364668867941,loss_f1_score:0.07286353311320592 \n", "训练集f1_score:1.0,测试集f1_score:0.9147023183361269,loss_f1_score:0.08529768166387308 \n", "训练集f1_score:0.9818127606280159,测试集f1_score:0.9017199309349478,loss_f1_score:0.09828006906505216 \n", "训练集f1_score:1.0,测试集f1_score:0.9144702886766378,loss_f1_score:0.08552971132336218 \n", "训练集f1_score:0.9987361493711533,测试集f1_score:0.9152462742627984,loss_f1_score:0.08475372573720164 \n", "训练集f1_score:1.0,测试集f1_score:0.9283825864164065,loss_f1_score:0.07161741358359353 \n", "训练集f1_score:1.0,测试集f1_score:0.9185245776900096,loss_f1_score:0.08147542230999039 \n", "训练集f1_score:1.0,测试集f1_score:0.9176200948292667,loss_f1_score:0.08237990517073335 \n", "训练集f1_score:0.9993129514194335,测试集f1_score:0.9174352830766729,loss_f1_score:0.08256471692332712 \n", "训练集f1_score:1.0,测试集f1_score:0.9276704131051788,loss_f1_score:0.07232958689482116 \n", "训练集f1_score:1.0,测试集f1_score:0.9268048760558437,loss_f1_score:0.07319512394415628 \n", "训练集f1_score:1.0,测试集f1_score:0.9304568955332027,loss_f1_score:0.06954310446679735 \n", "训练集f1_score:1.0,测试集f1_score:0.9222607611550148,loss_f1_score:0.07773923884498524 \n", "训练集f1_score:1.0,测试集f1_score:0.9303686983620825,loss_f1_score:0.06963130163791753 \n", "训练集f1_score:1.0,测试集f1_score:0.9275281467065163,loss_f1_score:0.07247185329348371 \n", "训练集f1_score:1.0,测试集f1_score:0.9263494542572851,loss_f1_score:0.0736505457427149 \n", "训练集f1_score:1.0,测试集f1_score:0.9262464202510822,loss_f1_score:0.07375357974891783 \n", "训练集f1_score:1.0,测试集f1_score:0.9213298706249988,loss_f1_score:0.07867012937500117 \n", "训练集f1_score:1.0,测试集f1_score:0.9255381820063792,loss_f1_score:0.07446181799362084 \n", 
"训练集f1_score:1.0,测试集f1_score:0.9262492441399471,loss_f1_score:0.07375075586005286 \n", "训练集f1_score:1.0,测试集f1_score:0.9267529385979496,loss_f1_score:0.0732470614020504 \n", "训练集f1_score:1.0,测试集f1_score:0.9279362552557409,loss_f1_score:0.07206374474425914 \n", "训练集f1_score:1.0,测试集f1_score:0.9105496558898486,loss_f1_score:0.0894503441101514 \n", "训练集f1_score:1.0,测试集f1_score:0.9255677088759965,loss_f1_score:0.07443229112400351 \n", "训练集f1_score:1.0,测试集f1_score:0.9258810998636311,loss_f1_score:0.0741189001363689 \n", "训练集f1_score:1.0,测试集f1_score:0.9236045683410877,loss_f1_score:0.07639543165891227 \n", "训练集f1_score:1.0,测试集f1_score:0.9236482035413927,loss_f1_score:0.07635179645860735 \n", "训练集f1_score:0.9998006409358186,测试集f1_score:0.9161826380576955,loss_f1_score:0.08381736194230449 \n", "训练集f1_score:1.0,测试集f1_score:0.9226427795765888,loss_f1_score:0.0773572204234112 \n", "训练集f1_score:1.0,测试集f1_score:0.9227047668043812,loss_f1_score:0.07729523319561882 \n", "训练集f1_score:1.0,测试集f1_score:0.9255689533534145,loss_f1_score:0.07443104664658551 \n", "训练集f1_score:1.0,测试集f1_score:0.9322007348532765,loss_f1_score:0.06779926514672352 \n", "训练集f1_score:1.0,测试集f1_score:0.9169573599775939,loss_f1_score:0.08304264002240613 \n", "训练集f1_score:1.0,测试集f1_score:0.9230059720988804,loss_f1_score:0.07699402790111964 \n", "训练集f1_score:1.0,测试集f1_score:0.922697478395862,loss_f1_score:0.07730252160413797 \n", "训练集f1_score:1.0,测试集f1_score:0.9079606352786754,loss_f1_score:0.09203936472132457 \n", "训练集f1_score:1.0,测试集f1_score:0.9229248123974857,loss_f1_score:0.0770751876025143 \n", "训练集f1_score:1.0,测试集f1_score:0.923913432252704,loss_f1_score:0.07608656774729605 \n", "训练集f1_score:1.0,测试集f1_score:0.9257200990324236,loss_f1_score:0.07427990096757642 \n", "训练集f1_score:1.0,测试集f1_score:0.9276995504041144,loss_f1_score:0.07230044959588555 \n", "训练集f1_score:1.0,测试集f1_score:0.9251348482525271,loss_f1_score:0.07486515174747288 \n", 
"训练集f1_score:1.0,测试集f1_score:0.9231090610362633,loss_f1_score:0.07689093896373667 \n", "训练集f1_score:1.0,测试集f1_score:0.9164413618677342,loss_f1_score:0.08355863813226583 \n", "训练集f1_score:1.0,测试集f1_score:0.9293008018695311,loss_f1_score:0.07069919813046888 \n", "训练集f1_score:1.0,测试集f1_score:0.9301285411934597,loss_f1_score:0.06987145880654033 \n", "100%|█████████████████████████████████████████████| 100/100 [33:56<00:00, 20.36s/trial, best loss: 0.06779926514672352]\n", "Search best param: {'bagging_fraction': 0.7310343530671259, 'boosting': 'gbdt', 'feature_fraction': 0.8644701162989126, 'learning_rate': 0.0483933201073737, 'min_data_in_leaf': 15, 'num_boost_round': 1100, 'num_leaves': 60, 'objective': 'multiclass', 'max_depth': 7, 'num_threads': 8, 'is_unbalance': True, 'metric': 'None', 'train_metric': True, 'verbose': -1, 'bagging_freq': 5, 'num_class': 3, 'feature_pre_filter': False}\n" ] } ], "source": [ "############## Search space for hyperparameter optimization ###################\n", "spaceParam = {\n", "    'boosting': hp.choice('boosting',['gbdt','dart']),\n", "    'learning_rate':hp.loguniform('learning_rate', np.log(0.01), np.log(0.05)),\n", "    'num_leaves': hp.quniform('num_leaves', 3, 66, 3),\n", "    'feature_fraction': hp.uniform('feature_fraction', 0.7,1),\n", "    'min_data_in_leaf': hp.quniform('min_data_in_leaf', 10, 50,5),\n", "    'num_boost_round':hp.quniform('num_boost_round',500,2000,100),\n", "    'bagging_fraction':hp.uniform('bagging_fraction',0.6,1)\n", "}\n", "# Hyperparameter optimization ---------------------------------------------------------------------------------\n", "def getParam(param):\n", "    # hp.quniform yields floats, so cast the integer-valued parameters\n", "    for k in ['num_leaves', 'min_data_in_leaf','num_boost_round']:\n", "        param[k] = int(float(param[k]))\n", "    for k in ['learning_rate', 'feature_fraction','bagging_fraction']:\n", "        param[k] = float(param[k])\n", "    # fmin returns the index of an hp.choice option, so map it back to its name\n", "    if param['boosting'] == 0:\n", "        param['boosting'] = 'gbdt'\n", "    elif param['boosting'] == 1:\n", "        param['boosting'] = 'dart'\n", "    # Add the fixed parameters\n", "    param['objective'] = 'multiclass'\n", "    
param['max_depth'] = 7\n", "    param['num_threads'] = 8\n", "    param['is_unbalance'] = True\n", "    param['metric'] = 'None'\n", "    param['train_metric'] = True\n", "    param['verbose'] = -1\n", "    param['bagging_freq'] = 5\n", "    param['num_class'] = 3\n", "    param['feature_pre_filter'] = False\n", "    return param\n", "\n", "def f1_score_eval(preds, valid_df):\n", "    labels = valid_df.get_label()\n", "    preds = np.argmax(preds.reshape(3, -1), axis=0)\n", "    scores = f1_score(y_true=labels, y_pred=preds, average='macro')\n", "    return 'f1_score', scores, True\n", "\n", "def lossFun(param):\n", "    param = getParam(param)\n", "    m = lgb.train(params=param,train_set=train_data,num_boost_round=param['num_boost_round'],\n", "                  valid_sets=[train_data,valid_data],valid_names=['train','valid'],\n", "                  feature_name=features,feval=f1_score_eval,\n", "                  early_stopping_rounds=earlyStopping,verbose_eval=False,keep_training_booster=True)\n", "    train_f1_score = m.best_score['train']['f1_score']\n", "    valid_f1_score = m.best_score['valid']['f1_score']\n", "    # hyperopt minimizes the loss, so use 1 - F1 on the validation split\n", "    loss_f1_score = 1 - valid_f1_score\n", "    # The value printed as '测试集' (test set) is the score on the validation split X_verify\n", "    print('训练集f1_score:{},测试集f1_score:{},loss_f1_score:{}'.format(train_f1_score, valid_f1_score, loss_f1_score))\n", "    return {'loss': loss_f1_score, 'params': param, 'status': STATUS_OK}\n", "\n", "features = model_feature\n", "train_data = lgb.Dataset(data=X_train[model_feature],label=y_train,feature_name=features)\n", "valid_data = lgb.Dataset(data=X_verify[features],label=y_verify,reference=train_data,feature_name=features)\n", "\n", "best_param = fmin(fn=lossFun, space=spaceParam, algo=tpe.suggest, max_evals=100, trials=Trials())\n", "best_param = getParam(best_param)\n", "print('Search best param:',best_param)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After feature selection and hyperparameter optimization, the final model uses the hyperparameters found by Bayesian optimization. It is trained with 5-fold cross-validation, and the per-fold predictions on the test set are summed and averaged." ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Use 200 features ...\n", "the 1 
training start ...\n", "Training until validation scores don't improve for 100 rounds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "D:\\SOFTWEAR_H\\Anaconda3\\lib\\site-packages\\lightgbm\\engine.py:151: UserWarning: Found `num_boost_round` in params. Will use it instead of argument\n", " warnings.warn(\"Found `{}` in params. Will use it instead of argument\".format(alias))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[100]\tvalid_0's f1_score: 0.894256\n", "[200]\tvalid_0's f1_score: 0.909942\n", "[300]\tvalid_0's f1_score: 0.913423\n", "[400]\tvalid_0's f1_score: 0.917897\n", "[500]\tvalid_0's f1_score: 0.920616\n", "Early stopping, best iteration is:\n", "[456]\tvalid_0's f1_score: 0.920717\n", "the 2 training start ...\n", "Training until validation scores don't improve for 100 rounds\n", "[100]\tvalid_0's f1_score: 0.918357\n", "[200]\tvalid_0's f1_score: 0.916436\n", "Early stopping, best iteration is:\n", "[140]\tvalid_0's f1_score: 0.92449\n", "the 3 training start ...\n", "Training until validation scores don't improve for 100 rounds\n", "[100]\tvalid_0's f1_score: 0.915242\n", "[200]\tvalid_0's f1_score: 0.927189\n", "[300]\tvalid_0's f1_score: 0.930614\n", "Early stopping, best iteration is:\n", "[238]\tvalid_0's f1_score: 0.930614\n", "the 4 training start ...\n", "Training until validation scores don't improve for 100 rounds\n", "[100]\tvalid_0's f1_score: 0.901683\n", "[200]\tvalid_0's f1_score: 0.912985\n", "[300]\tvalid_0's f1_score: 0.916988\n", "[400]\tvalid_0's f1_score: 0.92147\n", "[500]\tvalid_0's f1_score: 0.921353\n", "Early stopping, best iteration is:\n", "[411]\tvalid_0's f1_score: 0.922153\n", "the 5 training start ...\n", "Training until validation scores don't improve for 100 rounds\n", "[100]\tvalid_0's f1_score: 0.900975\n", "[200]\tvalid_0's f1_score: 0.908373\n", "[300]\tvalid_0's f1_score: 0.91384\n", "[400]\tvalid_0's f1_score: 0.917567\n", "Early stopping, best iteration is:\n", 
"[369]\tvalid_0's f1_score: 0.919843\n", "              precision    recall  f1-score   support\n", "\n", "           0     0.8726    0.9001    0.8861      1621\n", "           1     0.9569    0.8949    0.9249      1018\n", "           2     0.9586    0.9619    0.9603      4361\n", "\n", "    accuracy                         0.9379      7000\n", "   macro avg     0.9294    0.9190    0.9238      7000\n", "weighted avg     0.9385    0.9379    0.9380      7000\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":88: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " test_['label'] = np.argmax(test_pred, axis=1)\n" ] } ], "source": [ "def f1_score_eval(preds, valid_df):\n", "    labels = valid_df.get_label()\n", "    preds = np.argmax(preds.reshape(3, -1), axis=0)\n", "    scores = f1_score(y_true=labels, y_pred=preds, average='macro')\n", "    return 'f1_score', scores, True\n", "\n", "def sub_on_line_lgb(train_, test_, pred, label, cate_cols, split,\n", "                    is_shuffle=True,\n", "                    use_cart=False,\n", "                    get_prob=False):\n", "    n_class = 3\n", "    train_pred = np.zeros((train_.shape[0], n_class))\n", "    test_pred = np.zeros((test_.shape[0], n_class))\n", "    n_splits = 5\n", "\n", "    assert split in ['kf', 'skf'\n", "                     ], '{} Not Support this type of split way'.format(split)\n", "\n", "    if split == 'kf':\n", "        folds = KFold(n_splits=n_splits, shuffle=is_shuffle, random_state=1024)\n", "        kf_way = folds.split(train_[pred])\n", "    else:\n", "        # The main difference from KFold is stratified sampling: the training and\n", "        # validation parts of each fold keep the class proportions of the original data.\n", "        folds = StratifiedKFold(n_splits=n_splits,\n", "                                shuffle=is_shuffle,\n", "                                random_state=1024)\n", "        kf_way = folds.split(train_[pred], train_[label])\n", "\n", "    print('Use {} features ...'.format(len(pred)))\n", "    # Replace the parameters below with the ones found by Bayesian optimization\n", "    params = {\n", "        'learning_rate': 0.05,\n", "        'boosting_type': 'gbdt',\n", "        'objective': 'multiclass',\n", "        'metric': 'None',\n", " 
'num_leaves': 60,\n", "        'feature_fraction':0.86,\n", "        'bagging_fraction': 0.73,\n", "        'bagging_freq': 5,\n", "        'seed': 1,\n", "        'bagging_seed': 1,\n", "        'feature_fraction_seed': 7,\n", "        'min_data_in_leaf': 15,\n", "        'num_class': n_class,\n", "        'nthread': 8,\n", "        'verbose': -1,\n", "        'num_boost_round': 1100,\n", "        'max_depth': 7,\n", "    }\n", "    for n_fold, (train_idx, valid_idx) in enumerate(kf_way, start=1):\n", "        print('the {} training start ...'.format(n_fold))\n", "        train_x, train_y = train_[pred].iloc[train_idx\n", "                                             ], train_[label].iloc[train_idx]\n", "        valid_x, valid_y = train_[pred].iloc[valid_idx\n", "                                             ], train_[label].iloc[valid_idx]\n", "\n", "        if use_cart:\n", "            dtrain = lgb.Dataset(train_x,\n", "                                 label=train_y,\n", "                                 categorical_feature=cate_cols)\n", "            dvalid = lgb.Dataset(valid_x,\n", "                                 label=valid_y,\n", "                                 categorical_feature=cate_cols)\n", "        else:\n", "            dtrain = lgb.Dataset(train_x, label=train_y)\n", "            dvalid = lgb.Dataset(valid_x, label=valid_y)\n", "\n", "        clf = lgb.train(params=params,\n", "                        train_set=dtrain,\n", "                        # num_boost_round=3000,\n", "                        valid_sets=[dvalid],\n", "                        early_stopping_rounds=100,\n", "                        verbose_eval=100,\n", "                        feval=f1_score_eval)\n", "        # Out-of-fold predictions for the held-out part of the training set\n", "        train_pred[valid_idx] = clf.predict(valid_x,\n", "                                            num_iteration=clf.best_iteration)\n", "        # Average the test-set probabilities over the folds\n", "        test_pred += clf.predict(test_[pred],\n", "                                 num_iteration=clf.best_iteration) / folds.n_splits\n", "    print(classification_report(train_[label], np.argmax(train_pred,\n", "                                                         axis=1),\n", "                                digits=4))\n", "    if get_prob:\n", "        # Column suffixes: 围网 (purse seine), 刺网 (gillnet), 拖网 (trawl)\n", "        sub_probs = ['qyxs_prob_{}'.format(q) for q in ['围网', '刺网', '拖网']]\n", "        prob_df = pd.DataFrame(test_pred, columns=sub_probs)\n", "        prob_df['ID'] = test_['ID'].values\n", "        return prob_df\n", "    else:\n", "        test_['label'] = np.argmax(test_pred, axis=1)\n", "        return test_[['ID', 'label']]\n", "\n", "use_train = all_df[all_df['label'] != -1]\n", "# Copy the test rows so that assigning 'label' inside sub_on_line_lgb\n", "# does not trigger pandas' SettingWithCopyWarning.\n", "use_test = all_df[all_df['label'] == -1].copy()\n", "# use_feats = [c for c in use_train.columns if c not in ['ID', 'label']]\n", 
"use_feats=model_feature\n", "sub = sub_on_line_lgb(use_train, use_test, use_feats, 'label', [], 'kf',is_shuffle=True,use_cart=False,get_prob=False)" ] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }